import gc
import os
import pickle
import time
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge
from sklearn.linear_model import SGDRegressor
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR
from sklearn.utils import shuffle


def main(scal):
    if scal == 'general':
        n_train, n_test, n_features = 20000, 2000, 80
    else:
        n_train, n_test, n_features = 8000, 800, 60

    # #############################################################################
    # Benchmark bulk/atomic prediction speed for various regressors
    configuration = {
        'n_train': n_train,
        'n_test': n_test,
        'n_features': n_features,
        'estimators': [
            {'name': 'sgdregressor-{}'.format(scal),
             'instance': SGDRegressor(penalty='elasticnet', alpha=0.01,
                                      l1_ratio=0.25, tol=1e-4),
             'complexity_label': 'non-zero coefficients',
             'complexity_computer': lambda clf: np.count_nonzero(clf.coef_)},
            {'name': 'randomforestregressor-{}'.format(scal),
             'instance': RandomForestRegressor(),
             'complexity_label': 'estimators',
             'complexity_computer': lambda clf: clf.n_estimators},
            {'name': 'svr-{}'.format(scal),
             'instance': SVR(kernel='rbf'),
             'complexity_label': 'support vectors',
             'complexity_computer': lambda clf: len(clf.support_vectors_)}
        ]
    }

    train_models(configuration)

    # benchmark(configuration)

    # benchmark n_features influence on prediction speed
    # percentile = 90
    # percentiles = n_feature_influence({'ridge': Ridge()},
    #                                   configuration['n_train'],
    #                                   configuration['n_test'],
    #                                   [100, 250, 500], percentile)
    # plot_n_features_influence(percentiles, percentile)

    # benchmark throughput
    # throughputs = benchmark_throughputs(configuration)
    # plot_benchmark_throughput(throughputs, configuration)



def _not_in_sphinx():
    # Hack to detect whether we are being run by the sphinx builder
    return '__file__' in globals()


def atomic_benchmark_estimator(estimator, X_test, verbose=False):
    """Measure the prediction runtime of each instance, one at a time."""
    n_instances = X_test.shape[0]
    runtimes = np.zeros(n_instances, dtype=float)
    for i in range(n_instances):
        instance = X_test[[i], :]
        # time.perf_counter has a much finer resolution than time.time,
        # which matters for sub-millisecond single-instance predictions.
        start = time.perf_counter()
        estimator.predict(instance)
        runtimes[i] = time.perf_counter() - start
        if verbose:
            print("inference in {:.6f} s".format(runtimes[i]))
    if verbose:
        print("atomic_benchmark runtimes:", min(runtimes), np.percentile(
            runtimes, 50), max(runtimes))
    return runtimes


def bulk_benchmark_estimator(estimator, X_test, n_bulk_repeats, verbose):
    """Measure per-instance runtime when predicting the whole input at once."""
    n_instances = X_test.shape[0]
    runtimes = np.zeros(n_bulk_repeats, dtype=float)
    for i in range(n_bulk_repeats):
        start = time.perf_counter()
        estimator.predict(X_test)
        runtimes[i] = time.perf_counter() - start
    # Normalize each bulk runtime to a per-instance runtime.
    runtimes = runtimes / n_instances
    if verbose:
        print("bulk_benchmark runtimes:", min(runtimes), np.percentile(
            runtimes, 50), max(runtimes))
    return runtimes


def benchmark_estimator(estimator, X_test, n_bulk_repeats=30, verbose=False):
    """
    Measure runtimes of prediction in both atomic and bulk mode.

    Parameters
    ----------
    estimator : already trained estimator supporting `predict()`
    X_test : test input
    n_bulk_repeats : how many times to repeat when evaluating bulk mode

    Returns
    -------
    atomic_runtimes, bulk_runtimes : a pair of `np.array` which contain the
    runtimes in seconds.

    """
    atomic_runtimes = atomic_benchmark_estimator(estimator, X_test, verbose)
    bulk_runtimes = bulk_benchmark_estimator(estimator, X_test, n_bulk_repeats,
                                             verbose)
    return atomic_runtimes, bulk_runtimes
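

# A minimal usage sketch for benchmark_estimator (illustrative only, not called
# by the benchmark): fit the Ridge regressor imported above on a small
# synthetic problem and print median latencies for both prediction modes.
def _example_benchmark_ridge():
    X, y = make_regression(n_samples=1100, n_features=20, noise=0.1)
    estimator = Ridge().fit(X[:1000], y[:1000])
    atomic, bulk = benchmark_estimator(estimator, X[1000:])
    print("median atomic latency: {:.6f} s".format(np.median(atomic)))
    print("median bulk latency:   {:.6f} s".format(np.median(bulk)))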


def generate_dataset(n_train, n_test, n_features, noise=0.1, verbose=False):
    """Generate a regression dataset with the given parameters."""
    if verbose:
        print("generating dataset...")

    X, y, coef = make_regression(n_samples=n_train + n_test,
                                 n_features=n_features, noise=noise, coef=True)

    random_seed = 13
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=n_train, test_size=n_test, random_state=random_seed)
    X_train, y_train = shuffle(X_train, y_train, random_state=random_seed)

    X_scaler = StandardScaler()
    X_train = X_scaler.fit_transform(X_train)
    X_test = X_scaler.transform(X_test)

    y_scaler = StandardScaler()
    y_train = y_scaler.fit_transform(y_train[:, None])[:, 0]
    y_test = y_scaler.transform(y_test[:, None])[:, 0]

    gc.collect()
    if verbose:
        print("ok")
    return X_train, y_train, X_test, y_test
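

# A quick sanity-check sketch for generate_dataset (illustrative only, not
# called by the benchmark): after scaling, the training split should be
# zero-mean and unit-variance in every feature.
def _example_check_scaling():
    X_train, y_train, X_test, y_test = generate_dataset(1000, 100, 20)
    assert np.allclose(X_train.mean(axis=0), 0.0)
    assert np.allclose(X_train.std(axis=0), 1.0)
    print("test split feature means:", X_test.mean(axis=0).round(2))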


def boxplot_runtimes(runtimes, pred_type, configuration):
    """
    Plot a new `Figure` with boxplots of prediction runtimes.

    Parameters
    ----------
    runtimes : list of `np.array` of latencies in micro-seconds
    pred_type : 'bulk' or 'atomic'
    configuration : benchmark configuration; its estimators provide the labels

    """
    fig, ax = plt.subplots(figsize=(10, 6))
    plt.boxplot(runtimes)
    # Label each box with the estimator name and its complexity measure.
    cls_infos = ['%s\n(%d %s)' % (conf['name'],
                                  conf['complexity_computer'](conf['instance']),
                                  conf['complexity_label'])
                 for conf in configuration['estimators']]
    plt.setp(ax, xticklabels=cls_infos)
    ax.set_ylabel('Prediction Time (us)')
    ax.set_title('Prediction Time per Instance - %s' % pred_type.capitalize())
    plt.show()


def train_models(configuration):
    """Fit each estimator and pickle both the model and its test data."""
    X_train, y_train, X_test, y_test = generate_dataset(
        configuration['n_train'], configuration['n_test'],
        configuration['n_features'])

    # Make sure the output directories exist before writing the artifacts.
    model_dir = "../Runtimes/ml-inference-with-data-RT/models"
    test_data_dir = "../Runtimes/ml-inference-with-data-RT/testData"
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(test_data_dir, exist_ok=True)

    for estimator_conf in configuration['estimators']:
        print("Benchmarking", estimator_conf['instance'])
        estimator_conf['instance'].fit(X_train, y_train)
        gc.collect()
        name = estimator_conf['name']
        with open(os.path.join(model_dir, name), "wb") as model_file:
            pickle.dump(estimator_conf['instance'], model_file)
        with open(os.path.join(test_data_dir, name), "wb") as test_data_file:
            pickle.dump(X_test, test_data_file)
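

# A minimal sketch (illustrative only, not called anywhere) of how the pickled
# artifacts written by train_models could be loaded back for a one-off
# prediction; it assumes train_models has already populated the directories.
def _example_load_and_predict(name):
    with open("../Runtimes/ml-inference-with-data-RT/models/{}".format(name),
              "rb") as model_file:
        estimator = pickle.load(model_file)
    with open("../Runtimes/ml-inference-with-data-RT/testData/{}".format(name),
              "rb") as test_data_file:
        X_test = pickle.load(test_data_file)
    return estimator.predict(X_test[[0]])
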

def benchmark(configuration):
    """Run the whole benchmark."""

    X_train, y_train, X_test, y_test = generate_dataset(
        configuration['n_train'], configuration['n_test'],
        configuration['n_features'])

    stats = {}

    for estimator_conf in configuration['estimators']:
        name = estimator_conf['name']

        with open("../models/{}".format(name), "rb") as model:
            estimator_conf['instance'] = pickle.load(model)
            a, b = benchmark_estimator(estimator_conf['instance'], X_test)
            stats[estimator_conf['name']] = {'atomic': a, 'bulk': b}

    cls_names = [estimator_conf['name'] for estimator_conf in configuration['estimators']]
    runtimes = [1e6 * stats[clf_name]['atomic'] for clf_name in cls_names]
    boxplot_runtimes(runtimes, 'atomic', configuration)
    runtimes = [1e6 * stats[clf_name]['bulk'] for clf_name in cls_names]
    boxplot_runtimes(runtimes, 'bulk (%d)' % configuration['n_test'],
                     configuration)


# def n_feature_influence(estimators, n_train, n_test, n_features, percentile):
#     """
#     Estimate influence of the number of features on prediction time.
#
#     Parameters
#     ----------
#
#     estimators : dict of (name (str), estimator) to benchmark
#     n_train : number of training instances (int)
#     n_test : number of test instances (int)
#     n_features : list of feature-space dimensionality to test (int)
#     percentile : percentile at which to measure the speed (int [0-100])
#
#     Returns:
#     --------
#
#     percentiles : dict(estimator_name,
#                        dict(n_features, percentile_perf_in_us))
#
#     """
#     percentiles = defaultdict(defaultdict)
#     for n in n_features:
#         print("benchmarking with %d features" % n)
#         X_train, y_train, X_test, y_test = generate_dataset(n_train, n_test, n)
#         for cls_name, estimator in estimators.items():
#             estimator.fit(X_train, y_train)
#             gc.collect()
#             runtimes = bulk_benchmark_estimator(estimator, X_test, 30, False)
#             percentiles[cls_name][n] = 1e6 * np.percentile(runtimes,
#                                                            percentile)
#     return percentiles
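

# plot_n_features_influence is referenced in the commented-out benchmark in
# main but never defined in this script; below is a minimal sketch of what it
# could look like, assuming matplotlib and the percentiles dict documented in
# n_feature_influence above.
def plot_n_features_influence(percentiles, percentile):
    fig, ax = plt.subplots(figsize=(10, 6))
    for cls_name, n_feature_perf in percentiles.items():
        x = np.array(sorted(n_feature_perf.keys()))
        y = np.array([n_feature_perf[n] for n in x])
        ax.plot(x, y, label=cls_name)
    ax.set_xlabel('#features')
    ax.set_ylabel('prediction time at the %d-th percentile (us)' % percentile)
    ax.set_title('Evolution of Prediction Time with #Features')
    ax.legend()
    plt.show()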



def benchmark_throughputs(configuration, duration_secs=0.1):
    """Benchmark the single-instance prediction throughput of each estimator."""
    X_train, y_train, X_test, y_test = generate_dataset(
        configuration['n_train'], configuration['n_test'],
        configuration['n_features'])
    throughputs = dict()
    for estimator_config in configuration['estimators']:
        estimator_config['instance'].fit(X_train, y_train)
        start_time = time.perf_counter()
        n_predictions = 0
        # Repeatedly predict a single instance for duration_secs and count
        # how many predictions complete in that window.
        while (time.perf_counter() - start_time) < duration_secs:
            estimator_config['instance'].predict(X_test[[0]])
            n_predictions += 1
        throughputs[estimator_config['name']] = n_predictions / duration_secs
    return throughputs
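

# plot_benchmark_throughput is referenced in the commented-out benchmark in
# main but never defined here; a minimal bar-chart sketch, assuming matplotlib
# and the throughputs dict returned by benchmark_throughputs.
def plot_benchmark_throughput(throughputs, configuration):
    fig, ax = plt.subplots(figsize=(10, 6))
    cls_names = [conf['name'] for conf in configuration['estimators']]
    cls_values = [throughputs[name] for name in cls_names]
    x = np.arange(len(cls_names))
    ax.bar(x, cls_values)
    ax.set_xticks(x)
    ax.set_xticklabels(cls_names, fontsize=10)
    ax.set_ylabel('Throughput (predictions/sec)')
    ax.set_title('Prediction Throughput for Different Estimators')
    plt.show()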


if __name__ == '__main__':
    main('general')
